import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter

class CosineLinear(nn.Module):
    """Linear layer whose outputs are (optionally scaled) cosine similarities.

    Each output unit is the cosine of the angle between the input feature
    vector and that unit's weight row, multiplied by a single learnable
    scalar ``sigma`` when enabled.

    Args:
        hidden_dim: Size of the input feature (last) dimension.
        output_dim: Number of output units, i.e. rows of ``weight``.
        sigma: If true, create a learnable scalar scale initialised to 1;
            otherwise ``self.sigma`` is registered as ``None``.
    """

    def __init__(self, hidden_dim, output_dim, sigma=True):
        super(CosineLinear, self).__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        # Allocated uninitialised here; values are set in reset_parameters().
        self.weight = Parameter(torch.empty(output_dim, hidden_dim))
        if sigma:
            self.sigma = Parameter(torch.empty(1))
        else:
            self.register_parameter('sigma', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw ``weight`` uniformly and reset ``sigma`` (if any) to 1."""
        # Bound is 1 / sqrt(hidden_dim), taken from the weight's second dim.
        stdv = 1. / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.sigma is not None:
            nn.init.constant_(self.sigma, 1.)  # initial value of sigma

    def forward(self, input):
        """Return ``sigma * cos(input, weight)`` over the last dimension.

        Args:
            input: Tensor whose last dimension has size ``hidden_dim``; any
                number of leading dimensions (including none) is accepted.

        Returns:
            Tensor with the same leading dimensions as ``input`` and a last
            dimension of size ``output_dim``.
        """
        # Normalise along the feature (last) axis, which is the axis
        # F.linear contracts over. For 2-D input this equals dim=1.
        unit_input = F.normalize(input, p=2, dim=-1)
        unit_weight = F.normalize(self.weight, p=2, dim=1)
        out = F.linear(unit_input, unit_weight)

        if self.sigma is not None:
            out = self.sigma * out

        return out

class SplitCosineLinear(nn.Module):
    """Two cosine heads over the same input with concatenated outputs.

    ``fc0`` produces ``old_output_dim`` outputs and ``fc1`` produces
    ``new_output_dim`` outputs. Both heads are built without their own scale;
    one shared learnable scalar ``sigma`` (when enabled) multiplies the
    concatenated result.

    Args:
        hidden_dim: Size of the input feature dimension passed to both heads.
        old_output_dim: Number of outputs of the first head (``fc0``).
        new_output_dim: Number of outputs of the second head (``fc1``).
        sigma: If true, create a learnable scalar scale initialised to 1;
            otherwise ``self.sigma`` is registered as ``None``.
    """

    def __init__(self, hidden_dim, old_output_dim, new_output_dim, sigma=True):
        super(SplitCosineLinear, self).__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = old_output_dim + new_output_dim
        # Heads are created with sigma disabled so the scale is applied once,
        # after concatenation, instead of per head.
        self.fc0 = CosineLinear(hidden_dim, old_output_dim, False)
        self.fc1 = CosineLinear(hidden_dim, new_output_dim, False)
        if sigma:
            self.sigma = Parameter(torch.ones(1))
        else:
            self.register_parameter('sigma', None)

    def reset_parameters(self):
        """Re-initialise both heads and reset ``sigma`` (if any) to 1."""
        self.fc0.reset_parameters()
        self.fc1.reset_parameters()
        # Also restore the shared scale so a reset returns the module to its
        # freshly constructed state, as CosineLinear.reset_parameters does.
        if self.sigma is not None:
            self.sigma.data.fill_(1)

    def forward(self, x):
        """Apply both heads to ``x`` and concatenate along the last dim.

        Returns:
            Tensor whose last dimension has size ``self.output_dim``
            (``fc0`` outputs followed by ``fc1`` outputs), multiplied by
            ``sigma`` when it is present.
        """
        out0 = self.fc0(x)
        out1 = self.fc1(x)
        out = torch.cat((out0, out1), dim=-1)  # fc0 outputs first, then fc1
        if self.sigma is not None:
            out = self.sigma * out
        return out
